{
 "cells": [
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Training a neural network with DALI and JAX\n",
    "\n",
    "This simple example shows how to train a neural network implemented in JAX with DALI pipelines. It builds on MNIST training example from JAX codebase that can be found [here](https://github.com/google/jax/blob/jax-v0.4.13/examples/mnist_classifier_fromscratch.py).\n",
    "\n",
    "We will use MNIST in Caffe2 format from [DALI_extra](https://github.com/NVIDIA/DALI_extra)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2023-07-28T07:43:41.850101Z",
     "iopub.status.busy": "2023-07-28T07:43:41.849672Z",
     "iopub.status.idle": "2023-07-28T07:43:41.853520Z",
     "shell.execute_reply": "2023-07-28T07:43:41.852990Z"
    }
   },
   "outputs": [],
   "source": [
    "import os\n",
    "\n",
    "training_data_path = os.path.join(\n",
    "    os.environ[\"DALI_EXTRA_PATH\"], \"db/MNIST/training/\"\n",
    ")\n",
    "validation_data_path = os.path.join(\n",
    "    os.environ[\"DALI_EXTRA_PATH\"], \"db/MNIST/testing/\"\n",
    ")"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "First step is to create a definition function that will later be used to create instances of DALI iterators. It defines all steps of the preprocessing. \n",
    "\n",
    "In this simple example we have `fn.readers.caffe2` for reading data in Caffe2 format, `fn.decoders.image` for image decoding, `fn.crop_mirror_normalize` used to normalize the images and `fn.reshape` to adjust the shape of the output tensors. We also move the labels from the CPU to the GPU memory with `labels.gpu()`. Our model expects labels to be in one-hot encoding, so we use `fn.one_hot` to convert them.\n",
    "\n",
    "This example focuses on how to use DALI to train a model defined in JAX. For more information on DALI and JAX integration look into [Getting started with JAX and DALI](jax-getting_started.ipynb) and [pipeline documentation](../../../pipeline.rst)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2023-07-28T07:43:41.855441Z",
     "iopub.status.busy": "2023-07-28T07:43:41.855301Z",
     "iopub.status.idle": "2023-07-28T07:43:41.986406Z",
     "shell.execute_reply": "2023-07-28T07:43:41.985500Z"
    }
   },
   "outputs": [],
   "source": [
    "from nvidia.dali.plugin.jax import data_iterator\n",
    "import nvidia.dali.fn as fn\n",
    "import nvidia.dali.types as types\n",
    "\n",
    "\n",
    "batch_size = 200\n",
    "image_size = 28\n",
    "num_classes = 10\n",
    "\n",
    "\n",
    "@data_iterator(output_map=[\"images\", \"labels\"], reader_name=\"caffe2_reader\")\n",
    "def mnist_iterator(data_path, random_shuffle):\n",
    "    jpegs, labels = fn.readers.caffe2(\n",
    "        path=data_path, random_shuffle=random_shuffle, name=\"caffe2_reader\"\n",
    "    )\n",
    "    images = fn.decoders.image(jpegs, device=\"mixed\", output_type=types.GRAY)\n",
    "    images = fn.crop_mirror_normalize(\n",
    "        images, dtype=types.FLOAT, std=[255.0], output_layout=\"CHW\"\n",
    "    )\n",
    "    images = fn.reshape(images, shape=[image_size * image_size])\n",
    "\n",
    "    labels = labels.gpu()\n",
    "\n",
    "    if random_shuffle:\n",
    "        labels = fn.one_hot(labels, num_classes=num_classes)\n",
    "\n",
    "    return images, labels"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Next, we use the function to create DALI iterators for training and validation."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2023-07-28T07:43:41.989183Z",
     "iopub.status.busy": "2023-07-28T07:43:41.988964Z",
     "iopub.status.idle": "2023-07-28T07:43:42.104446Z",
     "shell.execute_reply": "2023-07-28T07:43:42.103668Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Creating iterators\n",
      "<nvidia.dali.plugin.jax.iterator.DALIGenericIterator object at 0x7f2894462ef0>\n",
      "<nvidia.dali.plugin.jax.iterator.DALIGenericIterator object at 0x7f28944634c0>\n"
     ]
    }
   ],
   "source": [
    "print(\"Creating iterators\")\n",
    "\n",
    "training_iterator = mnist_iterator(\n",
    "    data_path=training_data_path, random_shuffle=True, batch_size=batch_size\n",
    ")\n",
    "\n",
    "validation_iterator = mnist_iterator(\n",
    "    data_path=validation_data_path, random_shuffle=False, batch_size=batch_size\n",
    ")\n",
    "\n",
    "print(training_iterator)\n",
    "print(validation_iterator)"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "With the setup above, DALI iterators are ready for the training. \n",
    "\n",
    "Finally, we import training utilities implemented in JAX. `init_model` will create the model instance and initialize its parameters. In this simple example it is a MLP model with two hidden layers. `update` performs one iteration of the training. `accuracy` is a helper function to run validation after each epoch on the test set and get current accuracy of the model."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2023-07-28T07:43:43.559575Z",
     "iopub.status.busy": "2023-07-28T07:43:43.559420Z",
     "iopub.status.idle": "2023-07-28T07:43:43.618221Z",
     "shell.execute_reply": "2023-07-28T07:43:43.617532Z"
    }
   },
   "outputs": [],
   "source": [
    "from model import init_model, update, accuracy"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "At this point, everything is ready to run the training."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2023-07-28T07:43:43.622376Z",
     "iopub.status.busy": "2023-07-28T07:43:43.621205Z",
     "iopub.status.idle": "2023-07-28T07:43:58.016073Z",
     "shell.execute_reply": "2023-07-28T07:43:58.015333Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Starting training\n",
      "Epoch 0 sec\n",
      "Test set accuracy 0.67330002784729\n",
      "Epoch 1 sec\n",
      "Test set accuracy 0.7855000495910645\n",
      "Epoch 2 sec\n",
      "Test set accuracy 0.8251000642776489\n",
      "Epoch 3 sec\n",
      "Test set accuracy 0.8469000458717346\n",
      "Epoch 4 sec\n",
      "Test set accuracy 0.8616000413894653\n"
     ]
    }
   ],
   "source": [
    "print(\"Starting training\")\n",
    "\n",
    "model = init_model()\n",
    "num_epochs = 5\n",
    "\n",
    "for epoch in range(num_epochs):\n",
    "    for batch in training_iterator:\n",
    "        model = update(model, batch)\n",
    "\n",
    "    test_acc = accuracy(model, validation_iterator)\n",
    "    print(f\"Epoch {epoch} sec\")\n",
    "    print(f\"Test set accuracy {test_acc}\")"
   ]
  }
 ],
 "metadata": {
  "celltoolbar": "Raw Cell Format",
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.12"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
